/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
/*!
 * \file org_apache_tvm_native_c_api.cc
 * \brief tvm4j jni source file
 */
#include "org_apache_tvm_native_c_api.h"  // generated by javac
#ifdef TVM4J_ANDROID
#include "tvm_runtime.h"
#else
#include <dlfcn.h>
#include <tvm/ffi/c_api.h>
#include <tvm/ffi/container/shape.h>
#include <tvm/ffi/container/tensor.h>
#include <tvm/ffi/function.h>
#endif
#include <cstring>
#include <iostream>
#include <memory>
#include <thread>
#include <vector>

#include "jni_helper_func.h"

// Java VM hosting this library. Captured by nativeLibInit (JNIEnv::GetJavaVM) and used by the
// native callbacks to obtain a JNIEnv for the calling thread.
JavaVM* _jvm = nullptr;
// Handle returned by dlopen in nativeLibInit; nullptr until a library has been opened.
void* _tvmHandle = nullptr;

/*!
 * \brief Per-thread staging area for arguments pushed from Java before a TVM function call.
 *
 * The tvmFFIFunctionPushArg* entries append to packed_args; string and byte arguments also keep
 * the JNI resources backing their views here so they can be released once the call is done.
 */
struct TVMFFIJVMStack {
  /*! \brief Queued call arguments, in push order. */
  std::vector<tvm::ffi::AnyView> packed_args;
  /*! \brief Global ref + UTF chars of each pushed string, released after the call. */
  std::vector<std::pair<jstring, const char*>> str_args;
  /*! \brief Global ref + byte view of each pushed byte array, released after the call. */
  std::vector<std::pair<jbyteArray, std::unique_ptr<TVMFFIByteArray>>> byte_args;

  /*! \brief Returns the instance owned by the calling thread. */
  static TVMFFIJVMStack* ThreadLocal() {
    // Block-scope thread_local already has static storage duration.
    thread_local TVMFFIJVMStack instance;
    return &instance;
  }
};

/*!
 * \brief JNI entry for LibInfo.nativeLibInit: optionally dlopen a library and record the JVM.
 * \param jtvmLibFile Path of the library to open; ignored when null or when a library is
 *        already open.
 * \return 1 if dlopen fails (the dlerror() text is printed to stderr); otherwise the status
 *         returned by JNIEnv::GetJavaVM.
 */
JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_nativeLibInit(JNIEnv* env, jobject obj,
                                                                 jstring jtvmLibFile) {
  const bool shouldOpen = _tvmHandle == nullptr && !env->IsSameObject(jtvmLibFile, nullptr);
  if (shouldOpen) {
    const char* libPath = env->GetStringUTFChars(jtvmLibFile, 0);
    void* opened = dlopen(libPath, RTLD_LAZY | RTLD_GLOBAL);
    env->ReleaseStringUTFChars(jtvmLibFile, libPath);
    _tvmHandle = opened;
    if (opened == nullptr) {
      fprintf(stderr, "%s\n", dlerror());
      return 1;
    }
  }
  return env->GetJavaVM(&_jvm);
}

/*!
 * \brief JNI entry for LibInfo.shutdown: dlclose the library opened by nativeLibInit, if any.
 * \return Always 0.
 */
JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_shutdown(JNIEnv* env, jobject obj) {
  if (_tvmHandle) {
    dlclose(_tvmHandle);
    // Forget the closed handle: a repeated shutdown must not dlclose it twice, and a later
    // nativeLibInit must reopen the library instead of keeping a dangling handle.
    _tvmHandle = nullptr;
  }
  return 0;
}

/*!
 * \brief JNI entry for LibInfo.tvmFFIGetLastError: take the most recently raised safe-call error
 *        and return its message as a Java string.
 */
JNIEXPORT jstring JNICALL Java_org_apache_tvm_LibInfo_tvmFFIGetLastError(JNIEnv* env, jobject obj) {
  // Keep the error object alive while its message is copied into the Java string.
  auto raised = ::tvm::ffi::details::MoveFromSafeCallRaised();
  return env->NewStringUTF(raised.what());
}

// Function
/*! \brief JNI entry: queue a 64-bit integer argument on this thread's argument stack. */
JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFFIFunctionPushArgLong(JNIEnv* env,
                                                                             jobject obj,
                                                                             jlong arg) {
  const int64_t value = static_cast<int64_t>(arg);
  TVMFFIJVMStack* const stack = TVMFFIJVMStack::ThreadLocal();
  stack->packed_args.emplace_back(value);
}

/*! \brief JNI entry: queue a double-precision argument on this thread's argument stack. */
JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFFIFunctionPushArgDouble(JNIEnv* env,
                                                                               jobject obj,
                                                                               jdouble arg) {
  const double value = static_cast<double>(arg);
  TVMFFIJVMStack* const stack = TVMFFIJVMStack::ThreadLocal();
  stack->packed_args.emplace_back(value);
}

/*!
 * \brief JNI entry: queue a string argument.
 *
 * A global reference to the Java string and its UTF chars are recorded in str_args so the
 * characters stay valid until they are released after the call (by tvmFFIFunctionCall, or by
 * funcInvokeCallback for a callback's return value).
 */
JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFFIFunctionPushArgString(JNIEnv* env,
                                                                               jobject obj,
                                                                               jstring arg) {
  jstring pinned = reinterpret_cast<jstring>(env->NewGlobalRef(arg));
  const char* utf = env->GetStringUTFChars(pinned, 0);
  TVMFFIJVMStack& stack = *TVMFFIJVMStack::ThreadLocal();
  stack.str_args.emplace_back(pinned, utf);
  stack.packed_args.emplace_back(utf);
}

/*!
 * \brief JNI entry: queue a handle argument tagged with an explicit FFI type index.
 * \param arg Raw handle value, stored in the 64-bit payload of the TVMFFIAny.
 * \param argTypeIndex TVM FFI type index to record for the value.
 */
JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFFIFunctionPushArgHandle(JNIEnv* env,
                                                                               jobject obj,
                                                                               jlong arg,
                                                                               jint argTypeIndex) {
  TVMFFIAny raw;
  raw.type_index = static_cast<int>(argTypeIndex);
  raw.zero_padding = 0;
  raw.v_int64 = static_cast<int64_t>(arg);
  TVMFFIJVMStack::ThreadLocal()->packed_args.emplace_back(
      tvm::ffi::AnyView::CopyFromTVMFFIAny(raw));
}

/*!
 * \brief JNI entry: queue a DLDevice argument read from an org.apache.tvm.Device object.
 * \param arg Java Device whose int fields `deviceType` and `deviceId` are copied.
 */
JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFFIFunctionPushArgDevice(JNIEnv* env,
                                                                               jobject obj,
                                                                               jobject arg) {
  jclass deviceClass = env->FindClass("org/apache/tvm/Device");
  jfieldID deviceTypeField = env->GetFieldID(deviceClass, "deviceType", "I");
  jfieldID deviceIdField = env->GetFieldID(deviceClass, "deviceId", "I");
  jint deviceType = env->GetIntField(arg, deviceTypeField);
  jint deviceId = env->GetIntField(arg, deviceIdField);
  // Only the field values are needed from here on; drop the class reference like the other
  // class lookups in this file do.
  env->DeleteLocalRef(deviceClass);
  TVMFFIJVMStack* stack = TVMFFIJVMStack::ThreadLocal();
  stack->packed_args.emplace_back(DLDevice{static_cast<DLDeviceType>(deviceType), deviceId});
}

/*!
 * \brief JNI entry: queue a byte-array argument.
 *
 * The Java array is kept alive through a global reference and its elements are exposed as a
 * TVMFFIByteArray. Both are recorded in byte_args so they can be released (elements copied back
 * with mode 0) after the call.
 */
JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFFIFunctionPushArgBytes(JNIEnv* env,
                                                                              jobject obj,
                                                                              jbyteArray arg) {
  jbyteArray pinned = reinterpret_cast<jbyteArray>(env->NewGlobalRef(arg));
  jbyte* elems = env->GetByteArrayElements(pinned, 0);

  auto view = std::make_unique<TVMFFIByteArray>();
  view->data = reinterpret_cast<const char*>(elems);
  view->size = static_cast<size_t>(env->GetArrayLength(pinned));

  TVMFFIJVMStack& stack = *TVMFFIJVMStack::ThreadLocal();
  stack.packed_args.emplace_back(view.get());
  // The stack takes ownership of the view; the AnyView above only borrows it.
  stack.byte_args.emplace_back(pinned, std::move(view));
}

/*!
 * \brief JNI entry: append the name of every registered global TVM function to a java.util.List.
 * \param jfuncNames java.util.List receiving one java.lang.String per global function.
 * \return Status produced by the TVM_FFI_SAFE_CALL_BEGIN/END wrapper.
 */
JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFFIFunctionListGlobalNames(
    JNIEnv* env, jobject obj, jobject jfuncNames) {
  TVM_FFI_SAFE_CALL_BEGIN();
  jclass listCls = env->FindClass("java/util/List");
  jmethodID listAdd = env->GetMethodID(listCls, "add", "(Ljava/lang/Object;)Z");

  auto globalNames = tvm::ffi::Function::ListGlobalNames();
  for (const auto& globalName : globalNames) {
    jstring javaName = env->NewStringUTF(globalName.c_str());
    env->CallBooleanMethod(jfuncNames, listAdd, javaName);
    env->DeleteLocalRef(javaName);
  }

  env->DeleteLocalRef(listCls);
  TVM_FFI_SAFE_CALL_END();
}

/*!
 * \brief JNI entry: look up a global TVM function by name.
 * \param jname Name of the global function.
 * \param jhandle Object whose long field (set via setLongField) receives the function handle.
 * \return Status of TVMFFIFunctionGetGlobal.
 */
JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFFIFunctionGetGlobal(JNIEnv* env, jobject obj,
                                                                           jstring jname,
                                                                           jobject jhandle) {
  const char* name = env->GetStringUTFChars(jname, 0);
  TVMFFIByteArray name_bytes{name, strlen(name)};
  // Start from null so a failed lookup hands Java a null handle instead of stack garbage.
  TVMFFIObjectHandle handle = nullptr;
  int ret = TVMFFIFunctionGetGlobal(&name_bytes, &handle);
  env->ReleaseStringUTFChars(jname, name);
  setLongField(env, jhandle, reinterpret_cast<jlong>(handle));
  return ret;
}

/*!
 * \brief JNI entry for LibInfo.tvmFFIFunctionCall.
 *
 * Calls a TVM function with the arguments queued on this thread by the tvmFFIFunctionPushArg*
 * entries, then releases the JNI resources those arguments held.
 * \param jhandle Handle of the function to call.
 * \param jretVal Base.RefTVMValue whose `value` field receives the converted return value.
 * \return The status returned by TVMFFIFunctionCall.
 */
JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFFIFunctionCall(JNIEnv* env, jobject obj,
                                                                      jlong jhandle,
                                                                      jobject jretVal) {
  TVMFFIJVMStack* stack = TVMFFIJVMStack::ThreadLocal();
  // Detach the queued arguments from the thread-local stack before calling out. The callee can
  // re-enter Java (e.g. a function made by tvmFFIFunctionCreateFromCallback), and a TVM call made
  // from there pushes onto and then clears this same stack. With local copies, such a nested call
  // cannot do any of the following:
  // - see these arguments prepended to its own,
  // - reallocate the buffer handed to the callee,
  // - release strings/bytes the callee still references.
  decltype(TVMFFIJVMStack::packed_args) packed_args;
  decltype(TVMFFIJVMStack::str_args) str_args;
  decltype(TVMFFIJVMStack::byte_args) byte_args;
  packed_args.swap(stack->packed_args);
  str_args.swap(stack->str_args);
  byte_args.swap(stack->byte_args);

  TVMFFIAny ret_val;
  ret_val.type_index = tvm::ffi::TypeIndex::kTVMFFINone;
  ret_val.zero_padding = 0;
  ret_val.v_int64 = 0;
  int ret = TVMFFIFunctionCall(reinterpret_cast<TVMFFIObjectHandle>(jhandle),
                               reinterpret_cast<TVMFFIAny*>(packed_args.data()),
                               packed_args.size(), &ret_val);
  // Release all temporary resources backing the arguments.
  for (auto& str_pair : str_args) {
    env->ReleaseStringUTFChars(str_pair.first, str_pair.second);
    env->DeleteGlobalRef(str_pair.first);
  }
  for (auto& byte_pair : byte_args) {
    // Mode 0: copy the elements back into the Java array and free the native buffer.
    env->ReleaseByteArrayElements(
        byte_pair.first, reinterpret_cast<jbyte*>(const_cast<char*>(byte_pair.second->data)), 0);
    env->DeleteGlobalRef(byte_pair.first);
  }

  // Return the TVMValue object to Java.
  jclass refTVMValueCls = env->FindClass("org/apache/tvm/Base$RefTVMValue");
  jfieldID refTVMValueFid = env->GetFieldID(refTVMValueCls, "value", "Lorg/apache/tvm/TVMValue;");

  env->SetObjectField(jretVal, refTVMValueFid, tvmRetValueToJava(env, ret_val));
  env->DeleteLocalRef(refTVMValueCls);
  return ret;
}

// A helper object to take in JNIEnv ptr
// and allow automatic casting to both JNIEnv** and void**
// Background: different version of JDK may choose to have one signature
// or another for the case of AttachCurrentThread
// we use this universal helper object to enable compatibility with both
class JNIEnvPtrHelper {
 public:
  explicit JNIEnvPtrHelper(JNIEnv** penv) : penv_(penv) {}

  operator JNIEnv**() { return penv_; }

  operator void**() { return reinterpret_cast<void**>(penv_); }

 private:
  JNIEnv** penv_;
};

// Callback function
extern "C" int funcInvokeCallback(void* self, const TVMFFIAny* args, int num_args, TVMFFIAny* ret) {
  JNIEnv* env;
  int jniStatus = _jvm->GetEnv(reinterpret_cast<void**>(&env), JNI_VERSION_1_6);
  if (jniStatus == JNI_EDETACHED) {
    _jvm->AttachCurrentThread(JNIEnvPtrHelper(&env), nullptr);
  } else {
    TVM_FFI_ICHECK(jniStatus == JNI_OK);
  }

  jclass tvmValueCls = env->FindClass("org/apache/tvm/TVMValue");
  jobjectArray jargs = env->NewObjectArray(num_args, tvmValueCls, 0);

  for (int i = 0; i < num_args; ++i) {
    TVMFFIAny arg = args[i];
    if (args[i].type_index >= tvm::ffi::TypeIndex::kTVMFFIRawStr) {
      TVMFFIAnyViewToOwnedAny(&args[i], &arg);
    }
    jobject jarg = tvmRetValueToJava(env, arg);
    env->SetObjectArrayElement(jargs, i, jarg);
  }

  jclass clsFunc = env->FindClass("org/apache/tvm/Function");
  jmethodID invokeRegisteredCbFunc = env->GetStaticMethodID(
      clsFunc, "invokeRegisteredCbFunc",
      "(Lorg/apache/tvm/Function$Callback;[Lorg/apache/tvm/TVMValue;)Ljava/lang/Object;");
  jmethodID pushArgToStack =
      env->GetStaticMethodID(clsFunc, "pushArgToStack", "(Ljava/lang/Object;)V");
  jobject jretValue = env->CallStaticObjectMethod(clsFunc, invokeRegisteredCbFunc,
                                                  reinterpret_cast<jobject>(self), jargs);

  // the stack
  TVMFFIJVMStack* stack = TVMFFIJVMStack::ThreadLocal();
  const size_t prev_num_str_args = stack->str_args.size();
  const size_t prev_num_bytes_args = stack->byte_args.size();

  // convert returned (java) TVMValue to (C) TVMValue
  env->CallStaticVoidMethod(clsFunc, pushArgToStack, jretValue);

  TVMFFIAny ret_val = stack->packed_args.back().CopyToTVMFFIAny();
  stack->packed_args.pop_back();
  TVMFFIAnyViewToOwnedAny(&ret_val, ret);

  // release allocated strings.
  if (stack->str_args.size() > prev_num_str_args) {
    const auto& pairArg = stack->str_args.back();
    env->ReleaseStringUTFChars(pairArg.first, pairArg.second);
    env->DeleteGlobalRef(pairArg.first);
    stack->str_args.pop_back();
  }
  // release allocated bytes.
  if (stack->byte_args.size() > prev_num_bytes_args) {
    const auto& pairArg = stack->byte_args.back();
    env->ReleaseByteArrayElements(
        pairArg.first, reinterpret_cast<jbyte*>(const_cast<char*>(pairArg.second->data)), 0);
    env->DeleteGlobalRef(pairArg.first);
    stack->byte_args.pop_back();
  }

  env->DeleteLocalRef(clsFunc);
  env->DeleteLocalRef(tvmValueCls);
  return 0;
}

// Free callback function
/*!
 * \brief Deleter passed to TVMFFIFunctionCreate alongside funcInvokeCallback.
 * \param resourceHandle The global reference created in tvmFFIFunctionCreateFromCallback; it is
 *        deleted here.
 */
extern "C" void funcFreeCallback(void* resourceHandle) {
  JNIEnv* env = nullptr;
  const int status = _jvm->GetEnv(reinterpret_cast<void**>(&env), JNI_VERSION_1_6);
  if (status != JNI_EDETACHED) {
    TVM_FFI_ICHECK(status == JNI_OK);
  } else {
    // The release may happen on a thread the JVM has not seen yet.
    _jvm->AttachCurrentThread(JNIEnvPtrHelper(&env), nullptr);
  }
  env->DeleteGlobalRef(static_cast<jobject>(resourceHandle));
}

/*!
 * \brief JNI entry: wrap a Java callback object as a TVM function.
 * \param jfunction Java callback. A global reference to it becomes the function's resource and
 *        is deleted by funcFreeCallback.
 * \param jretHandle Object whose long field (set via setLongField) receives the function handle.
 * \return Status of TVMFFIFunctionCreate.
 */
JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFFIFunctionCreateFromCallback(
    JNIEnv* env, jobject obj, jobject jfunction, jobject jretHandle) {
  // Start from null so a failed creation reports a null handle instead of an uninitialized value.
  TVMFFIObjectHandle out = nullptr;
  int ret = TVMFFIFunctionCreate(reinterpret_cast<void*>(env->NewGlobalRef(jfunction)),
                                 funcInvokeCallback, funcFreeCallback, &out);
  setLongField(env, jretHandle, reinterpret_cast<jlong>(out));
  return ret;
}

/*!
 * \brief JNI entry: register a TVM function under a global name.
 * \param jname Global name to register.
 * \param jhandle Handle of the function to register.
 * \param joverride Override flag forwarded to TVMFFIFunctionSetGlobal.
 * \return Status of TVMFFIFunctionSetGlobal.
 */
JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFFIFunctionSetGlobal(JNIEnv* env, jobject obj,
                                                                           jstring jname,
                                                                           jlong jhandle,
                                                                           jint joverride) {
  const char* name = env->GetStringUTFChars(jname, 0);
  TVMFFIByteArray name_bytes{name, strlen(name)};
  // static_cast, not reinterpret_cast: reinterpret_cast between integer types is only valid when
  // they are the same type, and jint is `long` on some platforms.
  int ret = TVMFFIFunctionSetGlobal(&name_bytes, reinterpret_cast<TVMFFIObjectHandle>(jhandle),
                                    static_cast<int>(joverride));
  env->ReleaseStringUTFChars(jname, name);
  return ret;
}

// Object
/*!
 * \brief JNI entry: drop one reference to a TVM FFI object.
 * \param jhandle The object handle.
 * \return Status of TVMFFIObjectDecRef.
 */
JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFFIObjectFree(JNIEnv* env, jobject obj,
                                                                    jlong jhandle) {
  TVMFFIObjectHandle handle = reinterpret_cast<TVMFFIObjectHandle>(jhandle);
  const int status = TVMFFIObjectDecRef(handle);
  return status;
}

// Tensor

/*!
 * \brief JNI entry: append each dimension of a DLTensor's shape to a java.util.List as a Long.
 * \param jhandle Address of the DLTensor.
 * \param jshape java.util.List receiving the `ndim` dimensions, in order.
 * \return Always 0.
 */
JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFFIDLTensorGetShape(JNIEnv* env, jobject obj,
                                                                          jlong jhandle,
                                                                          jobject jshape) {
  const DLTensor* tensor = reinterpret_cast<DLTensor*>(jhandle);
  const int64_t* dims = tensor->shape;
  const int rank = tensor->ndim;

  jclass boxedLong = env->FindClass("java/lang/Long");
  jmethodID boxedLongCtor = env->GetMethodID(boxedLong, "<init>", "(J)V");
  jclass listCls = env->FindClass("java/util/List");
  jmethodID listAdd = env->GetMethodID(listCls, "add", "(Ljava/lang/Object;)Z");

  int axis = 0;
  while (axis < rank) {
    jobject boxed = env->NewObject(boxedLong, boxedLongCtor, static_cast<jlong>(dims[axis]));
    env->CallBooleanMethod(jshape, listAdd, boxed);
    env->DeleteLocalRef(boxed);
    ++axis;
  }

  env->DeleteLocalRef(boxedLong);
  env->DeleteLocalRef(listCls);
  return 0;
}

/*!
 * \brief JNI entry: copy one DLTensor into another via the global function
 *        "runtime.TVMTensorCopyFromTo".
 * \param jfrom Address of the source DLTensor.
 * \param jto Address of the destination DLTensor.
 * \return Status produced by the TVM_FFI_SAFE_CALL_BEGIN/END wrapper.
 */
JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFFIDLTensorCopyFromTo(JNIEnv* env,
                                                                            jobject obj,
                                                                            jlong jfrom,
                                                                            jlong jto) {
  TVM_FFI_SAFE_CALL_BEGIN();
  // Looked up once; the function object is cached for later calls.
  static auto copy_tensor = tvm::ffi::Function::GetGlobalRequired("runtime.TVMTensorCopyFromTo");
  DLTensor* src = reinterpret_cast<DLTensor*>(jfrom);
  DLTensor* dst = reinterpret_cast<DLTensor*>(jto);
  copy_tensor(src, dst);
  TVM_FFI_SAFE_CALL_END();
}

/*!
 * \brief JNI entry: fill a DLTensor from the bytes of a Java byte[] via the global function
 *        "runtime.TVMTensorCopyFromBytes".
 * \param jarr Source bytes; must hold at least GetDataSize(*to) bytes.
 * \param jto Address of the destination DLTensor.
 * \return Status produced by the TVM_FFI_SAFE_CALL_BEGIN/END wrapper.
 */
JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFFIDLTensorCopyFromJArray(JNIEnv* env,
                                                                                jobject obj,
                                                                                jbyteArray jarr,
                                                                                jlong jto) {
  TVM_FFI_SAFE_CALL_BEGIN();
  DLTensor* to = reinterpret_cast<DLTensor*>(jto);
  size_t size = tvm::ffi::GetDataSize(*to);
  // Reading `size` bytes from a shorter Java array would run past its end.
  TVM_FFI_ICHECK(static_cast<size_t>(env->GetArrayLength(jarr)) >= size);
  static auto fcopy_from_bytes =
      tvm::ffi::Function::GetGlobalRequired("runtime.TVMTensorCopyFromBytes");
  jbyte* pdata = env->GetByteArrayElements(jarr, nullptr);
  TVM_FFI_ICHECK(pdata != nullptr);
  // Release the elements on every exit path, including when the copy throws. JNI_ABORT because
  // the buffer is only read, so nothing needs to be written back to the Java array.
  struct ElementsRelease {
    JNIEnv* env;
    jbyteArray arr;
    jbyte* data;
    ~ElementsRelease() { env->ReleaseByteArrayElements(arr, data, JNI_ABORT); }
  } release{env, jarr, pdata};
  fcopy_from_bytes(to, static_cast<void*>(pdata), size);
  TVM_FFI_SAFE_CALL_END();
}

/*!
 * \brief JNI entry: copy a DLTensor's bytes into a Java byte[] via the global function
 *        "runtime.TVMTensorCopyToBytes".
 * \param jfrom Address of the source DLTensor.
 * \param jarr Destination array; must hold at least GetDataSize(*from) bytes.
 * \return Status produced by the TVM_FFI_SAFE_CALL_BEGIN/END wrapper.
 */
JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFFIDLTensorCopyToJArray(JNIEnv* env,
                                                                              jobject obj,
                                                                              jlong jfrom,
                                                                              jbyteArray jarr) {
  TVM_FFI_SAFE_CALL_BEGIN();
  DLTensor* from = reinterpret_cast<DLTensor*>(jfrom);
  size_t size = tvm::ffi::GetDataSize(*from);
  // Writing `size` bytes into a shorter Java array would overflow its buffer.
  TVM_FFI_ICHECK(static_cast<size_t>(env->GetArrayLength(jarr)) >= size);
  static auto fcopy_to_bytes =
      tvm::ffi::Function::GetGlobalRequired("runtime.TVMTensorCopyToBytes");
  jbyte* pdata = env->GetByteArrayElements(jarr, nullptr);
  TVM_FFI_ICHECK(pdata != nullptr);
  // Release the elements on every exit path. If the copy throws, discard the partially written
  // buffer (JNI_ABORT). On success, mode 0 copies it back into the Java array.
  struct ElementsRelease {
    JNIEnv* env;
    jbyteArray arr;
    jbyte* data;
    jint mode;
    ~ElementsRelease() { env->ReleaseByteArrayElements(arr, data, mode); }
  } release{env, jarr, pdata, JNI_ABORT};
  fcopy_to_bytes(from, static_cast<void*>(pdata), size);
  release.mode = 0;
  TVM_FFI_SAFE_CALL_END();
}

/*!
 * \brief JNI entry: synchronize a device via the global function "runtime.Device_StreamSync".
 * \param jdeviceType DLDevice type.
 * \param jdeviceId DLDevice id.
 * \return Status produced by the TVM_FFI_SAFE_CALL_BEGIN/END wrapper.
 * \note nullptr is passed as the stream (presumably the device's default stream; confirm).
 */
JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmSynchronize(JNIEnv* env, jobject obj,
                                                                  jint jdeviceType,
                                                                  jint jdeviceId) {
  TVM_FFI_SAFE_CALL_BEGIN();
  // Looked up once; the function object is cached for later calls.
  static auto stream_sync = tvm::ffi::Function::GetGlobalRequired("runtime.Device_StreamSync");
  DLDevice target{static_cast<DLDeviceType>(jdeviceType), jdeviceId};
  stream_sync(target, nullptr);
  TVM_FFI_SAFE_CALL_END();
}

/*!
 * \brief JNI entry for LibInfo.tvmTensorEmpty: allocate a tensor through the global function
 *        "runtime.TVMTensorAllocWithScope".
 * \param jshape Tensor dimensions.
 * \param jdtypeCode DLDataType code.
 * \param jdtypeBits DLDataType bits.
 * \param jdtypeLanes DLDataType lanes.
 * \param jdeviceType DLDevice type.
 * \param jdeviceId DLDevice id.
 * \param jret Object whose long field (set via setLongField) receives the tensor handle. The
 *        reference held by the allocated Tensor is moved into that handle.
 * \return Status produced by the TVM_FFI_SAFE_CALL_BEGIN/END wrapper.
 */
JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmTensorEmpty(
    JNIEnv* env, jobject obj, jlongArray jshape, jint jdtypeCode, jint jdtypeBits, jint jdtypeLanes,
    jint jdeviceType, jint jdeviceId, jobject jret) {
  TVM_FFI_SAFE_CALL_BEGIN();
  // Copy the dimensions out with GetLongArrayRegion rather than pinning the array with
  // Get/ReleaseLongArrayElements. The shape is only read, so nothing has to be written back to
  // Java, and no pinned buffer can be left behind if a later step throws.
  const jsize ndim = env->GetArrayLength(jshape);
  std::vector<jlong> dims(static_cast<size_t>(ndim));
  if (ndim > 0) {
    env->GetLongArrayRegion(jshape, 0, ndim, dims.data());
  }
  tvm::ffi::Shape shape(dims.data(), dims.data() + dims.size());
  DLDataType dtype;
  dtype.code = static_cast<uint8_t>(jdtypeCode);
  dtype.bits = static_cast<uint8_t>(jdtypeBits);
  dtype.lanes = static_cast<int16_t>(jdtypeLanes);
  DLDevice device{static_cast<DLDeviceType>(jdeviceType), jdeviceId};
  static auto fempty = tvm::ffi::Function::GetGlobalRequired("runtime.TVMTensorAllocWithScope");
  tvm::ffi::Tensor out = fempty(shape, dtype, device, nullptr).cast<tvm::ffi::Tensor>();
  void* handle = tvm::ffi::details::ObjectUnsafe::MoveObjectRefToTVMFFIObjectPtr(std::move(out));
  setLongField(env, jret, reinterpret_cast<jlong>(handle));
  TVM_FFI_SAFE_CALL_END();
}
